"""
Copyright (c) 2024 [XKW.Beijing]
All rights reserved.

Author: [tangxiaojun]
Email: [417281862@qq.com]
"""

from flask import Flask, jsonify, request
from werkzeug.exceptions import HTTPException

"""
导入自定义中间件
"""
from middleware_timing import TimingMiddleware
from middleware_auth import HttpBasicAuthMiddleware

"""
导入业务层
"""
from resource_user import UserResource
from resource_bill import BillResource
from resource_nlp import NlpResource
from resource_cczh import CczhResource
from resource_ltp import LtpResource
from resource_glm import GlmResource
from common_exception import register_error_handlers

app = Flask(__name__)

# Register the application's error handlers (defined in common_exception).
register_error_handlers(app)

# Register middleware. Initialization order is kept as-is: timing first,
# then HTTP basic auth.
middleware_timing = TimingMiddleware()
middleware_timing.init_app(app)

middleware_auth = HttpBasicAuthMiddleware()
middleware_auth.init_app(app)

# Instantiate the business-layer resources used by the route handlers.
user_resource = UserResource()
bill_resource = BillResource()
# NOTE: "npl" is a misspelling of "nlp"; the name is kept because the NLP
# route handlers in this module reference it.
npl_resource = NlpResource()
cczh_resource = CczhResource()
ltp_resource = LtpResource()
# NOTE(review): not referenced by any route visible in this module yet.
glm_resource = GlmResource()

# Load data: runs once, at module import time, before any request is served.
cczh_resource.load_model()

"""
-------------------- User API --------------------
"""


@app.route('/user/info/<user_id>', methods=['GET'])
def user_info(user_id):
    """Handle GET /user/info/<user_id>.

    Passes the path segment ``user_id`` to ``user_resource.get`` and
    returns its result serialized as JSON.
    """
    payload = user_resource.get(user_id)
    return jsonify(payload)


"""
--------------------Bill API --------------------
"""


@app.route('/bill/info', methods=['POST'])
def bill_info():
    """Handle POST /bill/info.

    Delegates to ``bill_resource.post()`` and returns its result as JSON.
    """
    result = bill_resource.post()
    return jsonify(result)


"""
-------------------- NLP API --------------------
"""


@app.route('/nlp/chinese-classification/predict', methods=['GET'])
def nlp_chinese_classification():
    """Handle GET /nlp/chinese-classification/predict.

    Returns the result of ``npl_resource.predict_model_list()`` as JSON.
    """
    predictions = npl_resource.predict_model_list()
    return jsonify(predictions)


@app.route('/nlp/chinese-classification-text', methods=['POST'])
def nlp_chinese_classification_text():
    """Handle POST /nlp/chinese-classification-text.

    Reads ``text`` from the query string and returns the result of
    ``npl_resource.predict_model_by_text(text)`` as JSON.

    Returns:
        A JSON response with the prediction, or a ``({"error": ...}, 400)``
        pair when ``text`` is missing or empty.
    """
    # The input comes from the query string (request.args), not the request
    # body, even though this is a POST route. Request contents are not echoed
    # to stdout.
    text = request.args.get('text')
    if not text:
        return jsonify(error='Missing required parameter: text'), 400
    return jsonify(npl_resource.predict_model_by_text(text))


@app.route('/nlp/fill-mask', methods=['POST'])
def nlp_fill_mask():
    """Handle POST /nlp/fill-mask.

    Reads ``text`` from the query string. When it is present and non-empty,
    returns ``npl_resource.fill_mask(text)`` as JSON. Otherwise responds with
    a JSON error and HTTP 400.
    """
    text = request.args.get('text')
    if text:
        filled = npl_resource.fill_mask(text)
        return jsonify(filled)
    return jsonify(error='Missing required parameter: text'), 400

@app.route('/nlp/qa', methods=['POST'])
def nlp_qa():
    """Handle POST /nlp/qa.

    Reads ``question`` and ``context`` from the query string. If either is
    missing or empty, responds with a JSON error and HTTP 400; ``question``
    is checked first. Otherwise returns
    ``npl_resource.qa_predict(question, context)`` as JSON.
    """
    question = request.args.get('question')
    context = request.args.get('context')

    # Validate in a fixed order so the first missing parameter is reported.
    required = (
        (question, 'Missing required parameter: question'),
        (context, 'Missing required parameter: context'),
    )
    for value, message in required:
        if not value:
            return jsonify(error=message), 400

    answer = npl_resource.qa_predict(question, context)
    return jsonify(answer)


"""
-------------------CCZH API --------------------
"""


@app.route('/cczh/get_nearest_neighbors', methods=['POST'])
def get_nearest_neighbors():
    """Handle POST /cczh/get_nearest_neighbors.

    Reads ``predictions_text`` from the query string and returns
    ``cczh_resource.get_nearest_neighbors(...)`` as JSON, or a JSON error
    with HTTP 400 when the parameter is missing or empty.
    """
    predictions_text = request.args.get('predictions_text')
    if predictions_text:
        neighbors = cczh_resource.get_nearest_neighbors(predictions_text)
        return jsonify(neighbors)
    return jsonify(error='Missing required parameter: predictions_text'), 400


"""
--------------------LTP API --------------------
"""


@app.route('/ltp/get_ner', methods=['POST'])
def get_ner():
    """Handle POST /ltp/get_ner.

    Reads ``predictions_text`` from the query string and returns
    ``ltp_resource.predict_model_by_text(...)`` as JSON, or a JSON error with
    HTTP 400 when the parameter is missing or empty.
    """
    # NOTE(review): this route delegates to predict_model_by_text rather than
    # an NER-named method; confirm that is the intended entry point.
    source_text = request.args.get('predictions_text')
    if not source_text:
        response = jsonify(error='Missing required parameter: predictions_text'), 400
    else:
        response = jsonify(ltp_resource.predict_model_by_text(source_text))
    return response


@app.route('/ltp/predict_words_ltp', methods=['POST'])
def predict_words_ltp():
    """Handle POST /ltp/predict_words_ltp.

    Reads ``predictions_text`` from the query string and returns
    ``ltp_resource.predict_words_ltp(...)`` as JSON, or a JSON error with
    HTTP 400 when the parameter is missing or empty.
    """
    query_text = request.args.get('predictions_text')
    if query_text:
        words = ltp_resource.predict_words_ltp(query_text)
        return jsonify(words)
    return jsonify(error='Missing required parameter: predictions_text'), 400


"""
--------------------GLM API --------------------
"""

if __name__ == '__main__':
    # When executed as a script, serve the app with Flask's built-in server.
    dev_port = 5000
    app.run(port=dev_port)
